# -*- coding: utf-8 -*-
"""
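Split the dataset into a training set and a test set, generating train.txt and
test.txt in ImageSets/Main.
Also generate the train_img_path.txt and test_img_path.txt files used by YOLO.
Optionally, only a given proportion of the annotated samples is selected.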

"""

import os
import xml.etree.ElementTree as ET
import argparse


def split_train_test_set(root_dir, target_dir, test_fraction=0.1, select_percent=1):
    """
    Split a VOC-style dataset into train and test subsets.

    Four files are (re)written inside ``target_dir``:

    * ``train.txt`` / ``test.txt``: annotation base names (no extension),
      one per line, as used by VOC tooling.
    * ``train_img_path.txt`` / ``test_img_path.txt``: image paths of the form
      ``<root_dir>/JPEGImages/<filename>``, one per line, as used for yolo
      training. ``<filename>`` is read from the ``<filename>`` tag of each
      annotation.

    Annotations are taken from ``<root_dir>/Annotations`` in sorted name order
    so that repeated runs give the same split. Only ``*.xml`` entries are
    considered. Sampling is systematic: every ``int(1/select_percent)``-th
    annotation is selected, and every ``int(1/test_fraction)``-th selected
    annotation goes to the test set; the rest of the selected ones go to the
    train set.

    :param root_dir: VOC root directory
    :param target_dir: existing directory in which the four lists are saved
    :param test_fraction: the fraction of selected samples put in the test set;
        values <= 1e-10 produce an empty test set, values >= 1 put every
        selected sample in the test set
    :param select_percent: the fraction of annotations to use at all
        (1 means use everything); must be positive
    :return: None
    :raises ValueError: if ``select_percent`` is not positive, or an annotation
        has no usable ``<filename>`` entry
    """
    if select_percent <= 0:
        raise ValueError(
            'select_percent must be positive, got {}'.format(select_percent))

    train_txt_path = os.path.join(target_dir, "train.txt")
    test_txt_path = os.path.join(target_dir, "test.txt")
    yolo_train_img_path_txt = os.path.join(target_dir, 'train_img_path.txt')
    yolo_test_img_path_txt = os.path.join(target_dir, 'test_img_path.txt')

    xml_path = os.path.join(root_dir, 'Annotations')

    # os.listdir returns entries in arbitrary order; sort for a reproducible
    # split and ignore stray non-annotation files.
    files = sorted(name for name in os.listdir(xml_path)
                   if name.lower().endswith('.xml'))
    xml_num = len(files)
    print('=========based on labels(xml): {} ============'.format(xml_num))

    if test_fraction <= 1e-10:
        # An interval larger than the number of files is never hit: no test set.
        sample_interval = xml_num + 1
    else:
        # Clamp to 1 so that test_fraction > 1 cannot yield a zero modulus.
        sample_interval = max(1, int(1.0 / test_fraction))
    select_interval = max(1, int(1.0 / select_percent))
    print('select sample interval is: {}'.format(select_interval))
    print('test sample interval is: {}'.format(sample_interval))
    test_interval = sample_interval * select_interval

    train_num = 0
    test_num = 0
    total_processed_num = 0

    # Opening in 'w' mode truncates previous results; the with-statement
    # guarantees all four files are closed even if parsing fails midway.
    with open(train_txt_path, "w") as train_txt, \
            open(test_txt_path, "w") as test_txt, \
            open(yolo_train_img_path_txt, 'w') as yolo_train_img_path, \
            open(yolo_test_img_path_txt, 'w') as yolo_test_img_path:
        for total_processed_num, xml_name in enumerate(files, 1):
            if total_processed_num % select_interval == 0:
                img_name = os.path.splitext(xml_name)[0]
                root = ET.parse(os.path.join(xml_path, xml_name)).getroot()
                # findtext gives None for a missing tag and '' for an empty one.
                file_name = root.findtext('filename')
                if not file_name:
                    raise ValueError(
                        'missing <filename> in annotation: {}'.format(
                            os.path.join(xml_path, xml_name)))

                img_path = os.path.join(root_dir, 'JPEGImages', file_name)

                if total_processed_num % test_interval == 0:
                    test_txt.write(str(img_name) + '\n')
                    yolo_test_img_path.write(img_path + '\n')
                    test_num = test_num + 1
                else:
                    train_txt.write(str(img_name) + '\n')
                    yolo_train_img_path.write(img_path + '\n')
                    train_num = train_num + 1

            # Progress is reported for skipped and selected files alike.
            if total_processed_num % 100 == 0:
                print('completed {} / {}'.format(total_processed_num, xml_num))

    print('total:', total_processed_num, 'train:', train_num, 'test:', test_num)


if __name__ == '__main__':
    # Command-line entry point: every option defaults to the value that was
    # previously hard-coded, so running the script without arguments behaves
    # as before.
    DEFAULT_ROOT_DIR = '/zongshiban_data/2/ccj/dataset/plane_ship_tank_VOC'
    PARSER = argparse.ArgumentParser(description="define all the file paths")
    PARSER.add_argument("--root_dir", type=str, default=DEFAULT_ROOT_DIR,
                        help="VOC root directory containing Annotations/ "
                             "and JPEGImages/")
    PARSER.add_argument("--target_dir", type=str, default=None,
                        help="output directory for the split lists "
                             "(default: <root_dir>/ImageSets/Main)")
    PARSER.add_argument("--test_fraction", type=float, default=0.1,
                        help="fraction of samples put in the test set")
    ARGS = PARSER.parse_args()
    ROOT_DIR = ARGS.root_dir
    TARGET_DIR = ARGS.target_dir
    if TARGET_DIR is None:
        # Derived after parsing so that it follows a user-supplied --root_dir.
        TARGET_DIR = os.path.join(ROOT_DIR, 'ImageSets', 'Main')
    if not os.path.isdir(TARGET_DIR):
        os.makedirs(TARGET_DIR)

    split_train_test_set(ROOT_DIR, TARGET_DIR, ARGS.test_fraction)
